/*
 * Copyright 2012-2016 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.security.oauth2.provider.token.store.redis;

import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.data.redis.connection.jedis.JedisConnection;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.security.authentication.TestingAuthenticationToken;
import org.springframework.security.oauth2.common.DefaultOAuth2AccessToken;
import org.springframework.security.oauth2.common.DefaultOAuth2RefreshToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.OAuth2Request;
import org.springframework.security.oauth2.provider.RequestTokenFactory;

import java.util.*;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;

/**
 * Tests for {@link RedisTokenStore} that run against a mocked
 * {@link JedisConnection}. Each test stores a token, captures the keys that were
 * written to the connection, removes the token again and then verifies that the
 * captured keys are cleaned up.
 *
 * @author Joe Grandja
 */
public class RedisTokenStoreMockTests {
	private JedisConnectionFactory connectionFactory;
	private JedisConnection connection;
	private RedisTokenStore tokenStore;
	private OAuth2Request request;
	private TestingAuthenticationToken authentication;

	/**
	 * Wires a {@link RedisTokenStore} to a mocked connection factory whose
	 * {@code getConnection()} hands out the mocked connection, and prepares the
	 * request / user authentication shared by the tests.
	 */
	@Before
	public void setUp() throws Exception {
		this.connection = mock(JedisConnection.class);
		this.connectionFactory = mock(JedisConnectionFactory.class);
		when(this.connectionFactory.getConnection()).thenReturn(this.connection);
		this.tokenStore = new RedisTokenStore(this.connectionFactory);

		List<String> scopeNames = Arrays.asList("read", "write", "read-write");
		Set<String> scopes = new LinkedHashSet<String>(scopeNames);
		this.request = RequestTokenFactory.createOAuth2Request("clientId", false, scopes);
		this.authentication = new TestingAuthenticationToken("user", "password");
	}

	/**
	 * gh-572: every key passed to {@code set} while storing a refresh token must be
	 * passed to {@code del} when that refresh token is removed.
	 */
	@Test
	public void storeRefreshTokenRemoveRefreshTokenVerifyKeysRemoved() {
		String tokenValue = "refresh-token-" + UUID.randomUUID();
		OAuth2RefreshToken refreshToken = new DefaultOAuth2RefreshToken(tokenValue);
		this.tokenStore.storeRefreshToken(refreshToken, newOAuth2Authentication());

		// Storing is expected to issue exactly two set(..) calls; remember their keys.
		ArgumentCaptor<byte[]> storedKeys = ArgumentCaptor.forClass(byte[].class);
		verify(this.connection, times(2)).set(storedKeys.capture(), any(byte[].class));

		// Canned pipeline results handed back by the mock during the removal call.
		List<Object> pipelineResults = new ArrayList<Object>(
				Arrays.<Object>asList(Long.valueOf(1), Long.valueOf(1), new byte[] {42}, Long.valueOf(1)));
		when(this.connection.closePipeline()).thenReturn(pipelineResults);

		this.tokenStore.removeRefreshToken(refreshToken);

		verifyDeleted(storedKeys.getAllValues());
	}

	/**
	 * gh-572: every key passed to {@code set} or {@code sAdd} while storing an
	 * access token (that has no refresh token) must be passed to {@code del} or
	 * {@code sRem} respectively when that access token is removed.
	 */
	@Test
	public void storeAccessTokenWithoutRefreshTokenRemoveAccessTokenVerifyKeysRemoved() {
		String tokenValue = "access-token-" + UUID.randomUUID();
		OAuth2AccessToken accessToken = new DefaultOAuth2AccessToken(tokenValue);
		OAuth2Authentication storedAuthentication = newOAuth2Authentication();

		// Spy on the JDK strategy so that deserializing an OAuth2Authentication
		// always yields the authentication used by this test.
		RedisTokenStoreSerializationStrategy strategy = new JdkSerializationStrategy();
		RedisTokenStoreSerializationStrategy strategySpy = spy(strategy);
		when(strategySpy.deserialize(any(byte[].class), eq(OAuth2Authentication.class)))
				.thenReturn(storedAuthentication);
		this.tokenStore.setSerializationStrategy(strategySpy);

		// Canned pipeline results handed back by the mock whenever a pipeline is closed.
		byte[] accessTokenBytes = "access-token".getBytes();
		byte[] authenticationBytes = "authentication".getBytes();
		List<Object> pipelineResults = Arrays.<Object>asList(accessTokenBytes, authenticationBytes);
		when(this.connection.closePipeline()).thenReturn(pipelineResults);

		this.tokenStore.storeAccessToken(accessToken, storedAuthentication);

		// Storing is expected to issue three set(..) and two sAdd(..) calls; remember their keys.
		ArgumentCaptor<byte[]> storedKeys = ArgumentCaptor.forClass(byte[].class);
		verify(this.connection, times(3)).set(storedKeys.capture(), any(byte[].class));
		ArgumentCaptor<byte[]> memberSetKeys = ArgumentCaptor.forClass(byte[].class);
		verify(this.connection, times(2)).sAdd(memberSetKeys.capture(), any(byte[].class));

		this.tokenStore.removeAccessToken(accessToken);

		verifyDeleted(storedKeys.getAllValues());
		for (byte[] memberSetKey : memberSetKeys.getAllValues()) {
			verify(this.connection).sRem(eq(memberSetKey), any(byte[].class));
		}
	}

	/**
	 * Builds an {@link OAuth2Authentication} from the request and user
	 * authentication prepared in {@link #setUp()}.
	 */
	private OAuth2Authentication newOAuth2Authentication() {
		return new OAuth2Authentication(this.request, this.authentication);
	}

	/**
	 * Verifies that {@code del} was invoked on the mocked connection once for each
	 * of the given keys.
	 */
	private void verifyDeleted(List<byte[]> keys) {
		for (byte[] key : keys) {
			verify(this.connection).del(key);
		}
	}

}